import numpy as np
import yaml
from tqdm import tqdm
import pandas as pd
import argparse


class SimpleHiddenMarkovModel:
    """Discrete hidden Markov model over the nucleotides A/T/G/C.

    All parameters are stored as natural logarithms so that likelihoods can
    be accumulated without underflow; a probability of zero is stored as
    ``-inf``.

    The ``config`` mapping passed to the constructor must provide:

    * ``'states'``: sequence of state names.
    * ``'transitions'``: iterable of mappings with the keys ``'from'``,
      ``'to'`` and ``'probability'``.
    * ``'emissions'``: mapping from state name to the emission
      probabilities, ordered like ``NUCLEOTIDES``.
    * ``'initial_probs'``: initial state probabilities, ordered like
      ``'states'``.
    """

    # Emission alphabet; a symbol's position is its column in every emission
    # probability vector.
    NUCLEOTIDES = ('A', 'T', 'G', 'C')
    # Reverse lookup for NUCLEOTIDES, built once instead of on every call.
    _NUCLEOTIDE_TO_INDEX = {'A': 0, 'T': 1, 'G': 2, 'C': 3}

    @staticmethod
    def define_transitions(states, transitions):
        """Build the dense transition probability matrix.

        Args:
            states: Sequence of state names; a state's position is its
                row/column index in the result.
            transitions: Iterable of mappings with the keys ``'from'``,
                ``'to'`` and ``'probability'``.

        Returns:
            Array of shape ``(len(states), len(states))`` whose entry
            ``[i, j]`` is the probability of moving from state ``i`` to
            state ``j``. Pairs that are not listed stay at 0.

        Raises:
            KeyError: If a transition names a state missing from ``states``.
        """
        transition_matrix = np.zeros((len(states), len(states)))
        state_index = {state: idx for idx, state in enumerate(states)}
        for transition in transitions:
            row = state_index[transition['from']]
            col = state_index[transition['to']]
            transition_matrix[row, col] = transition['probability']
        return transition_matrix

    def __init__(self, config):
        """Load the model parameters from ``config`` and move them to log space."""
        self.states = config['states']
        self.state_index = {state: idx for idx, state in enumerate(self.states)}
        transition_matrix = self.define_transitions(self.states, config['transitions'])
        # Zero probabilities are expected (sparse transitions, one-hot
        # emissions); log(0) = -inf is the intended encoding, so silence only
        # the divide-by-zero warning that NumPy would otherwise emit.
        with np.errstate(divide='ignore'):
            self.transition_probs = np.log(transition_matrix)
            self.emission_probs = {
                state: np.log(np.array(emissions))
                for state, emissions in config['emissions'].items()
            }
            self.initial_probs = np.log(np.array(config['initial_probs']))

    def _sample_from(self, current_index, length):
        """Emit ``length`` nucleotides, starting in state ``current_index``."""
        emitted = []
        num_states = len(self.states)
        for _ in range(length):
            state = self.states[current_index]
            emitted.append(np.random.choice(self.NUCLEOTIDES, p=np.exp(self.emission_probs[state])))
            # The transition is drawn after every emission, including the
            # last one, so the global random stream is consumed exactly as
            # before and seeded runs stay reproducible.
            current_index = np.random.choice(num_states, p=np.exp(self.transition_probs[current_index]))
        return ''.join(emitted)

    def generate_sequence(self, length):
        """Sample a nucleotide string of ``length`` symbols from the model.

        The start state is drawn from the initial distribution using NumPy's
        global random state.
        """
        start_index = np.random.choice(len(self.states), p=np.exp(self.initial_probs))
        return self._sample_from(start_index, length)

    def rest_of_subsequence(self, valid_start_states, length):
        """Sample ``length`` symbols starting from a fixed state.

        Args:
            valid_start_states: Index (into ``self.states``) of the state
                that emits the first symbol.
            length: Number of symbols to emit.
        """
        return self._sample_from(valid_start_states, length)

    def calculate_misalignment(self, sequence):
        """Count tokens of ``sequence`` that disagree with the model.

        Intended for models whose states have one-hot emission
        probabilities. Leading tokens are skipped until one matches the
        token of a state with non-zero initial probability; from there the
        sequence is compared position by position with a sequence sampled
        from the model. Because the reference is sampled, the counts are
        reproducible only when the sampling itself is deterministic.

        Returns:
            Tuple ``(wrong_start, wrong_transition, wrong_total)``. When no
            token can start the model, a message is printed and
            ``(len(sequence), 0, len(sequence))`` is returned.
        """
        # Token each state emits with the highest probability (the one-hot
        # token). If several states share a token, the last one listed wins.
        token_to_state = {}
        for state, log_probs in self.emission_probs.items():
            token = self.NUCLEOTIDES[np.argmax(np.exp(log_probs))]
            token_to_state[token] = self.state_index[state]

        # Tokens of the states the model is allowed to start in.
        start_tokens = {
            self.NUCLEOTIDES[np.argmax(np.exp(self.emission_probs[state]))]
            for state, prob in zip(self.states, np.exp(self.initial_probs))
            if prob > 0
        }

        start_idx = next(
            (idx for idx, token in enumerate(sequence) if token in start_tokens),
            None,
        )
        if start_idx is None:
            print("No valid start token found")
            return len(sequence), 0, len(sequence)

        # Every token skipped before the first valid start counts as wrong.
        count_wrong_start_token = start_idx
        expected = self.rest_of_subsequence(
            token_to_state[sequence[start_idx]], len(sequence) - start_idx
        )
        count_wrong_transition = sum(
            actual != reference
            for actual, reference in zip(sequence[start_idx:], expected)
        )
        return (
            count_wrong_start_token,
            count_wrong_transition,
            count_wrong_start_token + count_wrong_transition,
        )

    def calculate_likelihood(self, sequence):
        """Return the log-likelihood of ``sequence`` (forward algorithm).

        Only the current forward vector is kept, so memory does not grow
        with the sequence length.

        Returns:
            Natural-log likelihood. An empty sequence has probability 1 and
            therefore yields ``0.0``.

        Raises:
            KeyError: If ``sequence`` contains a symbol other than A/T/G/C.
        """
        if len(sequence) == 0:
            return 0.0

        observations = [self.nucleotide_index(symbol) for symbol in sequence]
        # Rows follow self.states, columns follow NUCLEOTIDES.
        emissions = np.stack([self.emission_probs[state] for state in self.states])

        forward = self.initial_probs + emissions[:, observations[0]]
        for observation in observations[1:]:
            # Entry [j, i] is forward[j] + log P(j -> i); reducing over j
            # with logaddexp sums the probabilities of all predecessors.
            arriving = np.logaddexp.reduce(forward[:, None] + self.transition_probs, axis=0)
            forward = arriving + emissions[:, observation]

        return np.logaddexp.reduce(forward)

    def nucleotide_index(self, nucleotide):
        """Map a nucleotide symbol to its column in the emission vectors.

        Raises:
            KeyError: If ``nucleotide`` is not one of A/T/G/C.
        """
        return self._NUCLEOTIDE_TO_INDEX[nucleotide]

    def calculate_perplexity(self, sequence):
        """Return ``exp(-log_likelihood / len(sequence))``.

        Perplexity is undefined for an empty sequence; ``nan`` is returned
        in that case instead of dividing by zero.
        """
        length = len(sequence)
        if length == 0:
            return float('nan')
        return np.exp(-self.calculate_likelihood(sequence) / length)

class GeneralHiddenMarkovModel:
    """Chain of per-segment HMMs that together describe one sequence.

    The YAML config must provide ``segments`` (a list of segment model
    configs, each with an ``id``) and ``segment_lengths`` with two parallel
    lists: ``names`` (segment ids, in sequence order) and ``lengths``
    (multiplied by the total length to size each segment; presumably
    fractions that sum to 1 -- confirm against the config files).
    """

    def __init__(self, config_file):
        """Load the segment models and their relative lengths from YAML."""
        with open(config_file, 'r') as file:
            config = yaml.safe_load(file)

        segment_configs = {segment['id']: segment for segment in config['segments']}
        self.segment_lengths = config['segment_lengths']['lengths']
        # One model per entry of 'names', in the order the segments appear.
        self.segment_models = [
            SimpleHiddenMarkovModel(segment_configs[name])
            for name in config['segment_lengths']['names']
        ]

    def _segment_token_counts(self, total_length):
        """Return the number of tokens assigned to each segment.

        Every segment gets ``total_length * fraction`` truncated to an
        integer; the last segment absorbs whatever remains so the counts
        always add up to ``total_length``.
        """
        counts = (total_length * np.array(self.segment_lengths)).astype(int)
        counts[-1] = total_length - sum(counts[:-1])
        return counts

    def _iter_segments(self, sequence, total_length):
        """Yield ``(segment_model, segment_slice)`` pairs covering ``sequence``."""
        start = 0
        for segment_idx, count in enumerate(self._segment_token_counts(total_length)):
            yield self.segment_models[segment_idx], sequence[start:start + count]
            start += count

    def generate_sequence(self, total_length):
        """Sample a sequence of ``total_length`` tokens, segment by segment."""
        counts = self._segment_token_counts(total_length)
        return ''.join(
            self.segment_models[segment_idx].generate_sequence(count)
            for segment_idx, count in enumerate(counts)
        )

    def calculate_likelihood(self, sequence, total_length):
        """Score ``sequence`` with the segment models.

        Returns:
            Tuple ``(total, per_segment)``: the per-segment values returned
            by each model's ``calculate_likelihood`` (log-likelihoods for
            ``SimpleHiddenMarkovModel``) and their sum.
        """
        likelihood_for_segments = [
            model.calculate_likelihood(segment_sequence)
            for model, segment_sequence in self._iter_segments(sequence, total_length)
        ]
        # Summing is the correct way to combine the segments because the
        # per-segment values are already in log space.
        likelihood = sum(likelihood_for_segments, 0.0)
        return likelihood, likelihood_for_segments

    def calculate_perplexity(self, sequence, total_length):
        """Return ``(overall, per_segment)`` perplexities of ``sequence``.

        The overall value is the geometric mean of the segment perplexities.
        """
        segment_perplexities = [
            model.calculate_perplexity(segment_sequence)
            for model, segment_sequence in self._iter_segments(sequence, total_length)
        ]
        overall_perplexity = np.exp(np.mean(np.log(segment_perplexities)))
        return overall_perplexity, segment_perplexities

    def calculate_misalignment(self, sequence, total_length):
        """Count misaligned tokens of ``sequence`` per segment.

        Returns:
            Three lists with one entry per segment: wrong start tokens,
            wrong transitions, and their total.
        """
        results = [
            model.calculate_misalignment(segment_sequence)
            for model, segment_sequence in self._iter_segments(sequence, total_length)
        ]
        count_wrong_start_token = [result[0] for result in results]
        count_wrong_transition = [result[1] for result in results]
        count_wrong_total = [result[2] for result in results]
        return count_wrong_start_token, count_wrong_transition, count_wrong_total

    def calculate_misalignment_multiple(self, sequences, total_length):
        """Accumulate per-segment misalignment counts over many sequences.

        Returns:
            Three arrays with one entry per segment, each summed over all
            sequences. To turn them into rates, divide by the number of
            tokens examined (sequence length times number of sequences).
        """
        count_wrong_start_token = []
        count_wrong_transition = []
        count_wrong_total = []
        for sequence in tqdm(sequences):
            count_start, count_transition, count_total = self.calculate_misalignment(sequence, total_length)
            count_wrong_start_token.append(count_start)
            count_wrong_transition.append(count_transition)
            count_wrong_total.append(count_total)

        # axis=0 sums over the sequences and keeps one total per segment.
        sum_wrong_start_token = np.sum(count_wrong_start_token, axis=0)
        sum_wrong_transition = np.sum(count_wrong_transition, axis=0)
        sum_wrong_total = np.sum(count_wrong_total, axis=0)
        return sum_wrong_start_token, sum_wrong_transition, sum_wrong_total


def save_to_csv(file_path, data, species='Apis mellifera (Honey bee).'):
    """Write sequences to a two-column CSV file (``species``, ``Sequence``).

    Args:
        file_path: Destination of the CSV file.
        data: Sized collection of sequences, one per output row.
        species: Label written in the ``species`` column of every row.
            Defaults to the honey-bee label this script has always used.
    """
    df = pd.DataFrame({
        'species': [species] * len(data),  # same label repeated for each sequence
        'Sequence': data,
    })
    # index=False keeps the DataFrame's row numbers out of the file.
    df.to_csv(file_path, index=False)

def load_fasta(file_path):
    """Read the sequences of a FASTA file, one string per ``>`` record.

    Header lines are discarded and the lines of each record are stripped
    and concatenated. Text that precedes the first header is ignored, so a
    file without any header yields an empty list.
    """
    records = []
    chunks = None  # stays None until the first header line has been seen
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            if not raw_line.startswith('>'):
                if chunks is not None:
                    chunks.append(raw_line.strip())
                continue
            # A header closes the record that was being collected, if any.
            if chunks is not None:
                records.append(''.join(chunks))
            chunks = []
    if chunks is not None:
        records.append(''.join(chunks))
    return records

def generate_sequences(seq_length, config_file, num_sequences=50000):
    """Sample ``num_sequences`` sequences of ``seq_length`` tokens each.

    The model is built once from ``config_file``; a progress bar tracks
    the sampling. Returns the sequences as a list of strings.
    """
    model = GeneralHiddenMarkovModel(config_file)
    return [model.generate_sequence(seq_length) for _ in tqdm(range(num_sequences))]

def compute_misalignment(sequences, seq_length, config_file):
    """Print misalignment counts and rates for ``sequences``.

    The counts are reported per segment; each rate is the count summed
    over all segments divided by ``seq_length * len(sequences)``.
    """
    model = GeneralHiddenMarkovModel(config_file)
    start_counts, transition_counts, total_counts = model.calculate_misalignment_multiple(sequences, seq_length)

    print(f'Num of wrong start token: {start_counts}')
    print(f'Num of wrong transition: {transition_counts}')
    print(f'Num of wrong start token + wrong transition: {total_counts}')

    # Number of tokens examined across the whole data set.
    token_count = seq_length * len(sequences)
    start_rate = np.sum(start_counts) / token_count
    print(f"Percentage of wrong start token: {start_rate}")
    transition_rate = np.sum(transition_counts) / token_count
    print(f"Percentage of wrong transition: {transition_rate}")
    total_rate = np.sum(total_counts) / token_count
    print(f"Percentage of wrong start token + wrong transition: {total_rate}")

def compute_likelihoods(sequences, seq_length, config_file):
    """Print average perplexity and likelihood of ``sequences``.

    Each sequence is scored by the model built from ``config_file``; both
    the overall values and the per-segment values are averaged over all
    sequences.
    """
    model = GeneralHiddenMarkovModel(config_file)

    # One (overall, per_segment) pair per sequence for each metric.
    perplexity_results = []
    likelihood_results = []
    for sequence in tqdm(sequences):
        perplexity_results.append(model.calculate_perplexity(sequence, seq_length))
        likelihood_results.append(model.calculate_likelihood(sequence, seq_length))

    overall_perplexities = [overall for overall, _ in perplexity_results]
    per_segment_perplexities = [per_segment for _, per_segment in perplexity_results]
    overall_likelihoods = [overall for overall, _ in likelihood_results]
    per_segment_likelihoods = [per_segment for _, per_segment in likelihood_results]

    mean_perplexity = np.mean(overall_perplexities)
    # axis=0 averages over the sequences and keeps one value per segment.
    mean_segment_perplexity = np.mean(per_segment_perplexities, axis=0)
    mean_likelihood = np.mean(overall_likelihoods)
    mean_segment_likelihood = np.mean(per_segment_likelihoods, axis=0)

    print(f'Average Perplexity: {mean_perplexity}')
    print(f'Average Segment Perplexity: {mean_segment_perplexity}')

    print(f'Average Likelihood: {mean_likelihood}')
    print(f'Average Segment Likelihood: {mean_segment_likelihood}')

def main(args):
    """Run the mode selected on the command line, then evaluate sequences.

    * ``generate``: sample sequences from the configured model, save all of
      them to ``markov/generated_sequences_markov/<config name>.csv`` and
      evaluate the first 4000.
    * ``compute``: load the sequences from the FASTA file ``args.file_path``
      and evaluate them.

    Raises:
        ValueError: If ``args.mode`` is neither ``'generate'`` nor
            ``'compute'``.
    """
    import os  # local import so this function does not rely on file-level changes

    if args.mode == 'generate':
        sequences = generate_sequences(args.seq_length, args.config_file, args.num_sequences)
        # Name the output after the config file: base name up to its first dot.
        config_file_name = os.path.basename(args.config_file).split('.')[0]
        output_dir = 'markov/generated_sequences_markov'
        # Create the directory if needed so a long generation run is not
        # lost to a missing folder at save time.
        os.makedirs(output_dir, exist_ok=True)
        save_to_csv(f'{output_dir}/{config_file_name}.csv', sequences)
        # Only a subset of the generated sequences is evaluated below.
        sequences = sequences[:4000]
    elif args.mode == 'compute':
        sequences = load_fasta(args.file_path)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"Unknown mode: {args.mode!r} (expected 'generate' or 'compute')")

    compute_likelihoods(sequences, args.seq_length, args.config_file)
    compute_misalignment(sequences, args.seq_length, args.config_file)

if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    parser = argparse.ArgumentParser(description="Generate sequences or compute likelihoods")
    parser.add_argument("mode", choices=['generate', 'compute'], help="Mode of operation: 'generate' to generate sequences, 'compute' to compute likelihoods")
    parser.add_argument("seq_length", type=int, help="Length of the sequences")
    parser.add_argument("--config_file", type=str, default='markov/config/test_model.yaml', help="Path to the config file")
    # Optional arguments
    parser.add_argument("--num_sequences", type=int, default=50000, help="Number of sequences to generate")
    # main() only reads --file_path in 'compute' mode, as the FASTA input;
    # the output path of 'generate' mode is derived from the config file name.
    parser.add_argument("--file_path", type=str, default='v2', help="Path to the FASTA file with the sequences to evaluate (used in 'compute' mode)")
    args = parser.parse_args()
    main(args)
